-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Fix qwen encoder hidden states mask #12655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Improves attention mask handling for QwenImage transformer by: - Adding support for variable-length sequence masking - Implementing dynamic attention mask generation from encoder_hidden_states_mask - Ensuring RoPE embedding works correctly with padded sequences - Adding comprehensive test coverage for masked input scenarios Performance and flexibility benefits: - Enables more efficient processing of sequences with padding - Prevents padding tokens from contributing to attention computations - Maintains model performance with minimal overhead
Improves file naming convention for the Qwen image mask performance benchmark script Enhances code organization by using a more descriptive and consistent filename that clearly indicates the script's purpose
|
@cdutr it's great that you have also included the benchmarking script for fullest transparency. But we can remove that from this PR and instead have that as a GitHub gist. The benchmark numbers make sense to me. Some comments:
Also, I think a natural next step would be see how well this performs when combined with FA varlen. WDYT? @naykun what do you think about the changes? |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
Thanks @sayakpaul! I removed the benchmark script, moved all tests to this gist. torch.compile testAlso tested the performance with Tested on NVIDIA A100 80GB PCIe: Also validated on RTX 4050 6GB (laptop) with similar results (2.38x speedup). The mask implementation is fully compatible with torch.compile. Image outputsTested End-to-end image generation: Successfully generated images using QwenImagePipeline and pipeline runs without errors, here is the output generated:
FA VarlenFA varlen is the natural next step, yes! I'm interested in working on it. Should I keep iterating in this PR, or should we merge it and create a new issue? The mask infrastructure this PR adds would translate well to varlen, instead of masking padding tokens, we'd pack only valid tokens using the same sequence length information |
|
Thanks for the results! Looks quite nice.
I think it's fine to first merge this PR and then we work on it afterwards. We're adding easier support for Sage and FA2 in this PR: #12439, so after that's merged, it will be quite easy to work on that (thanks to the Could we also check if the outputs deviate with and without the masks, i.e., the outputs we get on |
|
@dxqb would you maybe interested in checking this PR out as well? |
| ) | ||
|
|
||
| joint_attention_mask_1d = torch.cat([text_attention_mask, image_attention_mask], dim=1) | ||
| attention_mask = joint_attention_mask_1d[:, None, None, :] * joint_attention_mask_1d[:, None, :, None] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this works, but optimization possible:
this generates a real 2D mask in memory, of size seq_len * seq_len. An attention mask that is broadcastable to the required shape is enough:
attention_mask_2d = attention_mask[:, None, None, :]
Why this is enough: all tokens don't attend to the masked tokens anymore. whether the masked tokens attend to any other tokens is irrelevant, because they are masked in all layers and their result is never used.
I have tested this and the results were pixel-identical.
| ) | ||
|
|
||
| joint_attention_mask_1d = torch.cat([text_attention_mask, image_attention_mask], dim=1) | ||
| attention_mask = joint_attention_mask_1d[:, None, None, :] * joint_attention_mask_1d[:, None, :, None] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
another optimization:
attention masking is expensive, because torch SDPA switches to a flash attention algorithm internally if there is no attention mask. it cannot do that with an attention mask.
detecting a no-op attention mask can help:
attention_mask_2d = attention_mask[:, None, None, :] if not torch.all(text_attention_mask) else None
but you could also say that you expect the caller to not pass an attention mask if it's a no-op. also valid viewpoint.
| f"must match encoder_hidden_states sequence length ({text_seq_len})" | ||
| ) | ||
|
|
||
| text_attention_mask = encoder_hidden_states_mask.bool() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This works if the encoder_hidden_states_mask is already bool, or a float tensor with the same semantics.
bool attention masks are enough for the usual usecase of masking unused text tokens, but if only bool attention masks are supported this should be clearly documented. also maybe change the type hint?
see https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html how float attention masks are interpreted by torch. a float 0.0 is not masked, a bool False is masked.
there are some usecases for float attention masks for text sequences, like putting an emphasis/bias on certain tokens. not very common though, so if you decide to only support bool attention masks that makes sense to me - but requires documentation.
| # Use padded sequence length for RoPE when mask is present. | ||
| # The attention mask will handle excluding padding tokens. | ||
| if encoder_hidden_states_mask is not None: | ||
| txt_seq_lens_for_rope = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you please read this #12344 (comment)
I don't think this is how the txt_seq_lens parameter was intended.
However, your change here might still be a valid (temporary) fix, because it's currently (before the PR) not used as intended either
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just noticed, this was already discussed in other comments above.

What does this PR do?
Fixes the QwenImage encoder to properly apply
encoder_hidden_states_maskwhen passed to the model. Previously, the mask parameter was accepted but ignored, causing padding tokens to incorrectly influence attention computation.Changes
QwenDoubleStreamAttnProcessor2_0to create a 2D attention mask from the 1Dencoder_hidden_states_mask, properly masking text padding tokens while keeping all image tokens unmaskedImpact
This fix enables proper Classifier-Free Guidance (CFG) batching with variable-length text sequences, which is common when batching conditional and unconditional prompts together.
Benchmark Results
Overhead: +2.8% for mask processing without padding, +18.7% with actual padding (realistic CFG scenario)
The higher overhead with padding is expected and acceptable as it represents the cost of properly handling variable-length sequences in batched inference. This is a necessary correctness fix rather than an optimization. Test ran on RTX 4070 12GB.
Fixes #12294
Before submitting
Who can review?
@yiyixuxu @sayakpaul - Would appreciate your review, especially regarding the benchmarking approach. I used a custom benchmark rather than
BenchmarkMixinbecause:Note: The benchmark file is named
benchmarking_qwenimage_mask.py(with "benchmarking" prefix) rather thanbenchmark_qwenimage_mask.pyto prevent it from being picked up byrun_all.py, since it doesn't useBenchmarkMixinand produces a different CSV schema. If you prefer, I can adapt it to use the standard format instead.Happy to adjust the approach if you have suggestions!